/*
 * Digest Authentication Module
 *
 * Copyright (C) 2001-2003 FhG Fokus
 *
 * This file is part of ser, a free SIP server.
 *
 * ser is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version
 *
 * For a license to use the ser software under conditions
 * other than those described here, or to purchase support for this
 * software, please contact iptel.org by e-mail at the following addresses:
 *    info@iptel.org
 *
 * ser is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA
 *
 */

#include <stdio.h>
#include <string.h>
#include "../../core/sr_module.h"
#include "../../lib/srdb2/db.h"
#include "../../core/dprint.h"
#include "../../core/error.h"
#include "../../core/mem/mem.h"
#include "../../modules/auth/api.h"
#include "authorize.h"
#include "aaa_avps.h"
#include "uid_auth_db_mod.h"

MODULE_VERSION

#define TABLE_VERSION 7

/*
 * Module destroy function prototype
 */
static void destroy(void);


/*
 * Module child-init function prototype
 */
static int child_init(int rank);


/*
 * Module initialization function prototype
 */
static int mod_init(void);


static int authdb_fixup(void **param, int param_no);


#define USERNAME_COL "auth_username"
#define DID_COL "did"
#define REALM_COL "realm"
#define PASS_COL "ha1"
#define PASS_COL_2 "ha1b"
#define PLAIN_PASS_COL "password"
#define DEFAULT_CRED_LIST "uid"
#define FLAGS_COL "flags"

/*
 * Module parameter variables
 */
static char *db_url = DEFAULT_RODB_URL;

str username_column = STR_STATIC_INIT(USERNAME_COL);
str did_column = STR_STATIC_INIT(DID_COL);
str realm_column = STR_STATIC_INIT(REALM_COL);
str pass_column = STR_STATIC_INIT(PASS_COL);
str pass_column_2 = STR_STATIC_INIT(PASS_COL_2);
str flags_column = STR_STATIC_INIT(FLAGS_COL);
str plain_password_column = STR_STATIC_INIT(PLAIN_PASS_COL);

int calc_ha1 = 0;
int use_did = 0;
int check_all = 0;

db_ctx_t *auth_db_handle = 0; /* database connection handle */
auth_api_s_t auth_api;

str credentials_list = STR_STATIC_INIT(DEFAULT_CRED_LIST);

str *credentials;  /* Parsed list of credentials to load */
int credentials_n; /* Number of credentials in the list */


/*
 * Exported functions
 */
static cmd_export_t cmds[] = {{"www_authenticate", www_authenticate, 2,
									  authdb_fixup, 0, REQUEST_ROUTE},
		{"www_authorize", www_authenticate, 2, authdb_fixup, 0, REQUEST_ROUTE},
		{"proxy_authenticate", proxy_authenticate, 2, authdb_fixup, 0,
				REQUEST_ROUTE},
		{"proxy_authorize", proxy_authenticate, 2, authdb_fixup, 0,
				REQUEST_ROUTE},
		{0, 0, 0, 0, 0, 0}};


/*
 * Exported parameters
 */
static param_export_t params[] = {{"db_url", PARAM_STRING, &db_url},
		{"username_column", PARAM_STR, &username_column},
		{"did_column", PARAM_STR, &did_column},
		{"realm_column", PARAM_STR, &realm_column},
		{"password_column", PARAM_STR, &pass_column},
		{"password_column_2", PARAM_STR, &pass_column_2},
		{"plain_password_column", PARAM_STR, &plain_password_column},
		{"flags_column", PARAM_STR, &flags_column},
		{"calculate_ha1", PARAM_INT, &calc_ha1},
		{"load_credentials", PARAM_STR, &credentials_list},
		{"use_did", PARAM_INT, &use_did},
		{"check_all_ha1", PARAM_INT, &check_all}, {0, 0, 0}};


/*
 * Module interface
 */
struct module_exports exports = {
		"uid_auth_db",	 /* module name */
		DEFAULT_DLFLAGS, /* dlopen flags */
		cmds,			 /* exported functions */
		params,			 /* exported parameters */
		0,				 /* exported RPC methods */
		0,				 /* exported pseudo-variables */
		0,				 /* response function */
		mod_init,		 /* module init function */
		child_init,		 /* child init function */
		destroy			 /* destroy function */
};

static authdb_table_info_t *registered_tables = NULL;

static int generate_queries(authdb_table_info_t *info)
{
	db_fld_t match_with_did[] = {{.name = username_column.s, .type = DB_STR},
			{.name = realm_column.s, .type = DB_STR},
			{.name = did_column.s, .type = DB_STR}, {.name = NULL}};
	db_fld_t match_without_did[] = {{.name = username_column.s, .type = DB_STR},
			{.name = realm_column.s, .type = DB_STR}, {.name = NULL}};
	db_fld_t *result_cols = NULL;
	int len, i;

	len = sizeof(*result_cols) * (credentials_n + 3);
	result_cols = pkg_malloc(len);
	if(!result_cols) {
		PKG_MEM_ERROR;
		return -1;
	}
	memset(result_cols, 0, len);

	result_cols[0].name = pass_column.s;
	result_cols[0].type = DB_CSTR;

	result_cols[1].name = flags_column.s;
	result_cols[1].type = DB_INT;
	for(i = 0; i < credentials_n; i++) {
		result_cols[2 + i].name = credentials[i].s;
		result_cols[2 + i].type = DB_STR;
	}
	result_cols[2 + i].name = NULL;

	if(use_did) {
		info->query_pass = db_cmd(DB_GET, auth_db_handle, info->table.s,
				result_cols, match_with_did, NULL);
		result_cols[0].name = pass_column_2.s;
		info->query_pass2 = db_cmd(DB_GET, auth_db_handle, info->table.s,
				result_cols, match_with_did, NULL);
		result_cols[0].name = plain_password_column.s;
		info->query_password = db_cmd(DB_GET, auth_db_handle, info->table.s,
				result_cols, match_with_did, NULL);
	} else {
		info->query_pass = db_cmd(DB_GET, auth_db_handle, info->table.s,
				result_cols, match_without_did, NULL);
		result_cols[0].name = pass_column_2.s;
		info->query_pass2 = db_cmd(DB_GET, auth_db_handle, info->table.s,
				result_cols, match_without_did, NULL);
		result_cols[0].name = plain_password_column.s;
		info->query_password = db_cmd(DB_GET, auth_db_handle, info->table.s,
				result_cols, match_without_did, NULL);
	}

	pkg_free(result_cols);
	if(info->query_pass && info->query_pass2 && info->query_password)
		return 0;
	else
		return -1;
}

static int child_init(int rank)
{
	authdb_table_info_t *i;

	if(rank == PROC_INIT || rank == PROC_MAIN || rank == PROC_TCP_MAIN)
		return 0; /* do nothing for the main process */

	auth_db_handle = db_ctx("auth_db");
	if(!auth_db_handle)
		goto err;
	if(db_add_db(auth_db_handle, db_url) < 0)
		goto err;
	if(db_connect(auth_db_handle) < 0)
		goto err;

	/* initializing queries */
	i = registered_tables;
	while(i) {
		if(generate_queries(i) < 0) {
			LM_ERR("can't prepare queries\n");
			return -1;
		}
		i = i->next;
	}

	return 0;

err:

	if(auth_db_handle) {
		auth_db_handle = NULL;
		db_ctx_free(auth_db_handle);
	}

	LM_ERR("Error while initializing database layer\n");
	return -1;
}


static int mod_init(void)
{
	bind_auth_s_t bind_auth;

	LM_DBG("auth_db module - initializing\n");

	bind_auth = (bind_auth_s_t)find_export("bind_auth_s", 0, 0);
	if(!bind_auth) {
		LM_ERR("Unable to find bind_auth function\n");
		return -1;
	}
	if(bind_auth(&auth_api) < 0) {
		LM_ERR("Unable to bind auth module\n");
		return -3;
	}

	if(aaa_avps_init(&credentials_list, &credentials, &credentials_n)) {
		return -1;
	}

	return 0;
}


static void destroy(void)
{
	if(auth_db_handle) {
		db_ctx_free(auth_db_handle);
		auth_db_handle = NULL;
	}
}

static int str_case_equals(const str *a, const str *b)
{
	/* ugly hack: taken from libcds */
	int i;

	if(!a) {
		if(!b)
			return 0;
		else
			return (b->len == 0) ? 0 : 1;
	}
	if(!b)
		return (a->len == 0) ? 0 : 1;
	if(a->len != b->len)
		return 1;

	for(i = 0; i < a->len; i++)
		if(a->s[i] != b->s[i])
			return 1;
	return 0;
}

static authdb_table_info_t *find_table_info(str *table)
{
	authdb_table_info_t *i = registered_tables;

	/* sequential search is OK because it is called only in child init */
	while(i) {
		if(str_case_equals(&i->table, table) == 0)
			return i;
		i = i->next;
	}
	return NULL;
}

static authdb_table_info_t *register_table(str *table)
{
	authdb_table_info_t *info;

	info = find_table_info(table);
	if(info)
		return info; /* queries for this table already exist */

	info = (authdb_table_info_t *)pkg_malloc(
			sizeof(authdb_table_info_t) + table->len + 1);
	if(!info) {
		PKG_MEM_ERROR;
		return NULL;
	}

	info->table.s = info->buf;
	info->table.len = table->len;
	memcpy(info->table.s, table->s, table->len);
	info->table.s[table->len] = 0;

	/* append to the beginning (we don't care about order) */
	info->next = registered_tables;
	registered_tables = info;

	return info;
}
/*
 * Convert char* parameter to str* parameter
 */
static int authdb_fixup(void **param, int param_no)
{
	fparam_t *p;

	if(param_no == 1) {
		return fixup_var_str_12(param, param_no);
	} else if(param_no == 2) {
		if(fixup_var_str_12(param, param_no) < 0)
			return -1;
		p = (fparam_t *)(*param);
		if(p->type == FPARAM_STR) {
			*param = register_table(&p->v.str);
			if(!*param) {
				ERR("can't register table %.*s\n", p->v.str.len, p->v.str.s);
				return -1;
			}
		} else {
			LM_ERR("Non-string value of table with credentials"
				   " is not allowed.\n");
			/* TODO: allow this too */
			return -1;
		}
	}

	return 0;
}
